{
  "metadata": {
    "dataExplorerConfig": {},
    "bento_stylesheets": {
      "bento/extensions/flow/main.css": true,
      "bento/extensions/kernel_selector/main.css": true,
      "bento/extensions/kernel_ui/main.css": true,
      "bento/extensions/new_kernel/main.css": true,
      "bento/extensions/system_usage/main.css": true,
      "bento/extensions/theme/main.css": true
    },
    "kernelspec": {
      "display_name": "accelerators",
      "language": "python",
      "name": "bento_kernel_accelerators",
      "metadata": {
        "kernel_name": "bento_kernel_accelerators",
        "nightly_builds": true,
        "fbpkg_supported": true,
        "cinder_runtime": false,
        "is_prebuilt": true
      }
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3"
    },
    "last_server_session_id": "42b65868-6af0-4f04-bf2f-b7e2511f23dd",
    "last_kernel_id": "a08a7dfc-0fcc-4486-a2d5-604483260888",
    "last_base_url": "https://devgpu005.ftw6.facebook.com:8093/",
    "last_msg_id": "3f4cd9a4-65001843cf56aec954e05889_80",
    "outputWidgetContext": {}
  },
  "nbformat": 4,
  "nbformat_minor": 2,
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "aac0295c-e26e-45cb-b1b6-7796ee860152",
        "showInput": false,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": []
      },
      "source": [
        "The purpose of this example is to demonstrate the overall flow of lowering a PyTorch\n",
        "model to TensorRT via FX with existing FX based tooling. The general lowering flow would be like:\n",
        "1. Use splitter to split the model if there're ops in the model that we don't want to lower to TensorRT for some reasons like the ops are not supported in TensorRT or running them on other backends provides better performance.\n",
        "2. Lower the model (or part of the model if splitter is used) to TensorRT via fx path.\n",
        "If we know the model is fully supported by fx path (without op unsupported) then we can skip the splitter."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "ca68b029-68a6-42d6-968e-95bb7c1aae73",
        "showInput": true,
        "customInput": null,
        "collapsed": false,
        "requestMsgId": "f56944ff-ade2-4041-bdd6-3bce44b1405f",
        "customOutput": null,
        "executionStartTime": 1656367410991,
        "executionStopTime": 1656367412604
      },
      "source": [
        "import torch\n",
        "import torch.fx\n",
        "import torch.nn as nn\n",
        "from torch_tensorrt.fx.utils import LowerPrecision\n",
        "import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer\n",
        "from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule\n",
        "from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "8f974ab2-d187-4ffe-a09b-16cd85949be4",
        "showInput": true,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": [],
        "collapsed": false,
        "requestMsgId": "564359f5-ac69-4666-91e1-41b299495ed1",
        "customOutput": null,
        "executionStartTime": 1656367414494,
        "executionStopTime": 1656367422756
      },
      "source": [
        "class Model(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.linear = nn.Linear(10, 10)\n",
        "        self.relu = nn.ReLU()\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.linear(x)\n",
        "        x = self.relu(x)\n",
        "        x = torch.linalg.norm(x, ord=2, dim=1)\n",
        "        x = self.relu(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "inputs = [torch.randn((1, 10), device=torch.device('cuda'))]\n",
        "model = Model().cuda().eval()"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "0d407e92-e9e7-48aa-9c9e-1c21a9b5fd8f",
        "showInput": false,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": []
      },
      "source": [
        "acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators\n",
        "to acc ops."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "a1d9c8c2-8ec7-425a-8518-6f7e53ab1e67",
        "showInput": true,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": [],
        "collapsed": false,
        "requestMsgId": "ee2da608-5f1c-4f63-9927-544717e84e8a",
        "customOutput": null,
        "executionStartTime": 1656367480626,
        "executionStopTime": 1656367482881
      },
      "source": [
        "traced = acc_tracer.trace(model, inputs)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "246613eb-14b5-488e-9aae-35306fc99db1",
        "showInput": false,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": []
      },
      "source": [
        "Splitter will split the model into several submodules. The name of submodules will\n",
        "be either `run_on_acc_{}` or `run_on_gpu_{}`. Submodules named `run_on_acc_{}` can\n",
        "be fully lowered to TensorRT via fx2trt while submodules named `run_on_gpu_{}` has\n",
        "unsupported ops and can't be lowered by fx2trt. We can still run `run_on_gpu_{}`\n",
        "submodules on GPU if ops there have cuda implementation.\n",
        ""
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "1103c70a-3766-4d89-ad2f-cdcb1c3891e0",
        "showInput": true,
        "customInput": null,
        "collapsed": false,
        "requestMsgId": "feb888ea-ef9c-4577-b0c6-cf95bc1dd25e",
        "customOutput": null,
        "executionStartTime": 1656367487073,
        "executionStopTime": 1656367487154
      },
      "source": [
        "splitter = TRTSplitter(traced, inputs)"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "3d65e07e-57ed-47d5-adb9-4685c69c9c6b",
        "showInput": false,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": []
      },
      "source": [
        "Preview functionality allows us to see what are the supported ops and unsupported\n",
        "ops. We can optionally the dot graph which will color supported ops and unsupported\n",
        "ops differently."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "6aaed2d5-61b7-438e-a72a-63f91d0709e2",
        "showInput": true,
        "customInput": null,
        "collapsed": false,
        "requestMsgId": "2948c2f8-854b-4bc2-b399-321469da320c",
        "customOutput": null,
        "executionStartTime": 1656367489373,
        "executionStopTime": 1656367489556
      },
      "source": [
        "splitter.node_support_preview()"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\nSupported node types in the model:\nacc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})\nacc_ops.relu: ((), {'input': torch.float32})\n\nUnsupported node types in the model:\nacc_ops.linalg_norm: ((), {'input': torch.float32})\n\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": "\"\\nSupported node types in the model:\\nacc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})\\nacc_ops.relu: ((), {'input': torch.float32})\\n\\nUnsupported node types in the model:\\nacc_ops.linalg_norm: ((), {'input': torch.float32})\\n\""
          },
          "metadata": {
            "bento_obj_id": "139812830161136"
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "8d8035ab-869e-4096-b8e1-3539a0cfe1af",
        "showInput": false,
        "customInput": null
      },
      "source": [
        "After split, there are three submodules, _run_on_acc_0 and _run_on_gpu_1. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "80e03730-955a-4cc8-b071-7f92a2cff3df",
        "showInput": true,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": [],
        "collapsed": false,
        "requestMsgId": "2ca46574-7176-4699-a809-2a2e2d5ffda0",
        "customOutput": null,
        "executionStartTime": 1656367495077,
        "executionStopTime": 1656367495250
      },
      "source": [
        "split_mod = splitter()\n",
        "print(split_mod.graph)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Got 2 acc subgraphs and 1 non-acc subgraphs\ngraph():\n    %x : [#users=1] = placeholder[target=x]\n    %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})\n    %_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})\n    %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})\n    return _run_on_acc_2\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "9ce75161-978e-468e-9989-ecdbc9af0d5b",
        "showInput": true,
        "customInput": null,
        "collapsed": false,
        "requestMsgId": "0370de27-39ec-4be0-826b-9aec90df1155",
        "customOutput": null,
        "executionStartTime": 1656367496353,
        "executionStopTime": 1656367496452
      },
      "source": [
        "print(split_mod._run_on_acc_0.graph)\n",
        "print(split_mod._run_on_gpu_1.graph)\n",
        "print(split_mod._run_on_acc_2.graph)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "graph():\n    %x : [#users=1] = placeholder[target=x]\n    %linear_weight : [#users=1] = get_attr[target=linear.weight]\n    %linear_bias : [#users=1] = get_attr[target=linear.bias]\n    %linear_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear](args = (), kwargs = {input: %x, weight: %linear_weight, bias: %linear_bias})\n    %relu_2 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linear_1, inplace: False})\n    return relu_2\ngraph():\n    %relu_2 : [#users=1] = placeholder[target=relu_2]\n    %linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), kwargs = {input: %relu_2, ord: 2, dim: 1, keepdim: False})\n    return linalg_norm_1\ngraph():\n    %linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]\n    %relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})\n    return relu_3\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "7a6857bc-fedd-4847-ba17-a5d114de34f3",
        "showInput": false,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": []
      },
      "source": [
        "The `split_mod` contains the child modules supported by TRT or eager gpu. We can iterate them to transform the module into TRT engine."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "174fd2eb-a864-49cf-a204-6d24a8e2849d",
        "showInput": true,
        "customInput": null,
        "collapsed": false,
        "requestMsgId": "cf7fdfe4-e781-47c8-9a9a-85b5664c10f7",
        "customOutput": null,
        "executionStartTime": 1656367502837,
        "executionStopTime": 1656367510024,
        "code_folding": [],
        "hidden_ranges": []
      },
      "source": [
        "def get_submod_inputs(mod, submod, inputs):\n",
        "    acc_inputs = None\n",
        "\n",
        "    def get_input(self, inputs):\n",
        "        nonlocal acc_inputs\n",
        "        acc_inputs = inputs\n",
        "\n",
        "    handle = submod.register_forward_pre_hook(get_input)\n",
        "    mod(*inputs)\n",
        "    handle.remove()\n",
        "    return acc_inputs\n",
        "\n",
        "# Since the model is splitted into three segments. We need to lower each TRT eligible segment.\n",
        "# If we know the model can be fully lowered, we can skip the splitter part.\n",
        "for name, _ in split_mod.named_children():\n",
        "    if \"_run_on_acc\" in name:\n",
        "        submod = getattr(split_mod, name)\n",
        "        # Get submodule inputs for fx2trt\n",
        "        acc_inputs = get_submod_inputs(split_mod, submod, inputs)\n",
        "\n",
        "        # fx2trt replacement\n",
        "        interp = TRTInterpreter(\n",
        "            submod,\n",
        "            InputTensorSpec.from_tensors(acc_inputs),\n",
        "            explicit_batch_dimension=True,\n",
        "        )\n",
        "        r = interp.run(lower_precision=LowerPrecision.FP32)\n",
        "        trt_mod = TRTModule(*r)\n",
        "        setattr(split_mod, name, trt_mod)\n",
        "\n",
        "lowered_model_output = split_mod(*inputs)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "I0627 150503.073 fx2trt.py:190] Run Module elapsed time: 0:00:00.014965\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "I0627 150504.996 fx2trt.py:241] Build TRT engine elapsed time: 0:00:01.922029\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "I0627 150505.026 fx2trt.py:190] Run Module elapsed time: 0:00:00.000302\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "I0627 150509.953 fx2trt.py:241] Build TRT engine elapsed time: 0:00:04.925192\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "f1db3e1e-3a70-4735-a403-baa557b0f8a6",
        "showInput": false,
        "customInput": null,
        "code_folding": [],
        "hidden_ranges": []
      },
      "source": [
        "Model can be saved by torch.save and loaded with torch.load. Then we can compare the results with eager mode inference. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "a7c4fa0f-cac6-4959-8fa6-13b3455137d3",
        "showInput": true,
        "customInput": null,
        "collapsed": false,
        "requestMsgId": "f0c264ac-2bda-4c8e-a236-e2bd475e601e",
        "customOutput": null,
        "executionStartTime": 1656367515833,
        "executionStopTime": 1656367516184
      },
      "source": [
        "torch.save(split_mod, \"trt.pt\")\n",
        "reload_trt_mod = torch.load(\"trt.pt\")\n",
        "reload_model_output = reload_trt_mod(*inputs)\n",
        "\n",
        "# Make sure the results match\n",
        "regular_model_output = model(*inputs)\n",
        "torch.testing.assert_close(\n",
        "    reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2\n",
        ")"
      ],
      "execution_count": 9,
      "outputs": []
    }
  ]
}
